{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "de5fe140-f950-4f55-a69d-32b200ad8a7a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/conda/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "===================================BUG REPORT===================================\n",
      "Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n",
      "================================================================================\n",
      "CUDA_SETUP: WARNING! libcudart.so not found in any environmental path. Searching /usr/local/cuda/lib64...\n",
      "CUDA SETUP: CUDA runtime path found: /usr/local/cuda/lib64/libcudart.so\n",
      "CUDA SETUP: Highest compute capability among GPUs detected: 8.0\n",
      "CUDA SETUP: Detected CUDA version 117\n",
      "CUDA SETUP: Loading binary /usr/local/conda/lib/python3.10/site-packages/bitsandbytes-0.37.2-py3.10.egg/bitsandbytes/libbitsandbytes_cuda117.so...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/conda/lib/python3.10/site-packages/bitsandbytes-0.37.2-py3.10.egg/bitsandbytes/cuda_setup/main.py:136: UserWarning: WARNING: The following directories listed in your path were found to be non-existent: {PosixPath('/usr/local/nvidia/lib'), PosixPath('/usr/local/cuda/extras/CUPTI/lib64'), PosixPath('/usr/local/nvidia/lib64')}\n",
      "  warn(msg)\n",
      "/usr/local/conda/lib/python3.10/site-packages/bitsandbytes-0.37.2-py3.10.egg/bitsandbytes/cuda_setup/main.py:136: UserWarning: /usr/local/nvidia/lib:/usr/local/nvidia/lib64:/usr/local/cuda/extras/CUPTI/lib64 did not contain libcudart.so as expected! Searching further paths...\n",
      "  warn(msg)\n",
      "/usr/local/conda/lib/python3.10/site-packages/bitsandbytes-0.37.2-py3.10.egg/bitsandbytes/cuda_setup/main.py:136: UserWarning: WARNING: The following directories listed in your path were found to be non-existent: {PosixPath('module'), PosixPath('//matplotlib_inline.backend_inline')}\n",
      "  warn(msg)\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import bitsandbytes as bnb\n",
    "from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM\n",
    "\n",
    "device=\"cuda\"\n",
    "model = AutoModelForCausalLM.from_pretrained(\"/workspace/model/bloomz-3b\", load_in_8bit=True)\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"/workspace/model/bloomz-3b\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "836e6354-38e2-471c-966e-71614d35c5dc",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/conda/lib/python3.10/site-packages/peft/utils/other.py:136: FutureWarning: prepare_model_for_int8_training is deprecated and will be removed in a future version. Use prepare_model_for_kbit_training instead.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "from peft import prepare_model_for_int8_training\n",
    "model = prepare_model_for_int8_training(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "45138378-5788-4cbd-a938-d2d969f56c26",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 4915200 || all params: 3007472640 || trainable%: 0.1634329082375293\n"
     ]
    }
   ],
   "source": [
    "def print_trainable_parameters(model):\n",
    "    \"\"\"\n",
    "    Prints the number of trainable parameters in the model.\n",
    "    \"\"\"\n",
    "    trainable_params = 0\n",
    "    all_param = 0\n",
    "    for _, param in model.named_parameters():\n",
    "        all_param += param.numel()\n",
    "        if param.requires_grad:\n",
    "            trainable_params += param.numel()\n",
    "    print(\n",
    "        f\"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}\"\n",
    "    )\n",
    "\n",
    "from peft import LoraConfig, get_peft_model\n",
    "\n",
    "config = LoraConfig(\n",
    "    r=16, lora_alpha=32, target_modules=[\"query_key_value\"], lora_dropout=0.05, bias=\"none\", task_type=\"CAUSAL_LM\"\n",
    ")\n",
    "\n",
    "model = get_peft_model(model, config)\n",
    "model = model.to(device)\n",
    "print_trainable_parameters(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b482dbfd-95dd-47a6-8762-124d613756dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import transformers\n",
    "from datasets import load_dataset\n",
    "\n",
    "# data = load_dataset(\"Abirate/english_quotes\")\n",
    "\n",
    "data = load_dataset(\"/workspace/data/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96\")\n",
    "data = data.map(lambda samples: tokenizer(samples[\"quote\"]), batched=True)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0de2b961-d394-422c-8ef3-3c55893dec2a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You're using a BloomTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
      "/usr/local/conda/lib/python3.10/site-packages/bitsandbytes-0.37.2-py3.10.egg/bitsandbytes/autograd/_functions.py:298: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
      "  warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='20' max='20' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [20/20 01:55, Epoch 0/1]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>3.211400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>3.373700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>3.170300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>3.282100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>3.187100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>3.196800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>2.952900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>3.212800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>3.168700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>3.116700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>11</td>\n",
       "      <td>3.131100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>12</td>\n",
       "      <td>2.964200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>13</td>\n",
       "      <td>2.955700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>14</td>\n",
       "      <td>3.010200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>15</td>\n",
       "      <td>3.048800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>16</td>\n",
       "      <td>3.031200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>17</td>\n",
       "      <td>2.885200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>18</td>\n",
       "      <td>3.000700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>19</td>\n",
       "      <td>2.994000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>20</td>\n",
       "      <td>3.083600</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "checkpoint folder:  outputs/checkpoint-10\n",
      "checkpoint folder list:  ['README.md', 'adapter_config.json', 'adapter_model', 'adapter_model.safetensors', 'optimizer.pt', 'rng_state.pth', 'scheduler.pt', 'trainer_state.json', 'training_args.bin']\n",
      "checkpoint adapter folder list:  ['README.md', 'adapter_config.json', 'adapter_model.bin']\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/conda/lib/python3.10/site-packages/bitsandbytes-0.37.2-py3.10.egg/bitsandbytes/autograd/_functions.py:298: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
      "  warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "checkpoint folder:  outputs/checkpoint-20\n",
      "checkpoint folder list:  ['README.md', 'adapter_config.json', 'adapter_model', 'adapter_model.safetensors', 'optimizer.pt', 'rng_state.pth', 'scheduler.pt', 'trainer_state.json', 'training_args.bin']\n",
      "checkpoint adapter folder list:  ['README.md', 'adapter_config.json', 'adapter_model.bin']\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=20, training_loss=3.098853278160095, metrics={'train_runtime': 122.735, 'train_samples_per_second': 10.429, 'train_steps_per_second': 0.163, 'total_flos': 2954973215784960.0, 'train_loss': 3.098853278160095, 'epoch': 0.51})"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl\n",
    "from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR\n",
    "\n",
    "class SavePeftModelCallback(TrainerCallback):\n",
    "    def on_save(\n",
    "        self,\n",
    "        args: TrainingArguments,\n",
    "        state: TrainerState,\n",
    "        control: TrainerControl,\n",
    "        **kwargs,\n",
    "    ):\n",
    "        checkpoint_folder = os.path.join(args.output_dir, f\"{PREFIX_CHECKPOINT_DIR}-{state.global_step}\")\n",
    "        print(\"checkpoint folder: \",checkpoint_folder)\n",
    "        peft_model_path = os.path.join(checkpoint_folder, \"adapter_model\")\n",
    "        kwargs[\"model\"].save_pretrained(peft_model_path)\n",
    "\n",
    "        \n",
    "        files = os.listdir(checkpoint_folder)\n",
    "        print(\"checkpoint folder list: \", files)\n",
    "        adapter_files = os.listdir(peft_model_path)\n",
    "        print(\"checkpoint adapter folder list: \", adapter_files)\n",
    "        \n",
    "        pytorch_model_path = os.path.join(checkpoint_folder, \"pytorch_model.bin\")\n",
    "        if os.path.exists(pytorch_model_path):\n",
    "            os.remove(pytorch_model_path)\n",
    "        return control\n",
    "\n",
    "args = transformers.TrainingArguments(\n",
    "        per_device_train_batch_size=2,\n",
    "        gradient_accumulation_steps=4,\n",
    "        warmup_steps=5,\n",
    "        max_steps=20,\n",
    "        learning_rate=2e-4,\n",
    "        fp16=True,\n",
    "        logging_steps=1,\n",
    "        output_dir=\"outputs\",\n",
    "        save_strategy = 'steps',\n",
    "        save_steps = 10\n",
    "    )\n",
    "\n",
    "trainer = transformers.Trainer(\n",
    "    model=model,\n",
    "    train_dataset=data[\"train\"],\n",
    "    args=args,\n",
    "    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),\n",
    "    callbacks=[SavePeftModelCallback()],\n",
    ")\n",
    "\n",
    "model.config.use_cache = False  # silence the warnings. Please re-enable for inference!\n",
    "\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "44cf46c9-dbaf-45f0-9a05-694ce53b74f1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/conda/lib/python3.10/site-packages/transformers/generation/utils.py:1591: UserWarning: You are calling .generate() with the `input_ids` being on a device type different than your model's device. `input_ids` is on cpu, whereas the model is on cuda. You may experience unexpected behaviors or slower generation. Please make sure that you have put `input_ids` to the correct device by calling for example input_ids = input_ids.to('cuda') before running `.generate()`.\n",
      "  warnings.warn(\n",
      "/usr/local/conda/lib/python3.10/site-packages/bitsandbytes-0.37.2-py3.10.egg/bitsandbytes/autograd/_functions.py:298: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
      "  warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "output：\n",
      "\n",
      " Two things are infinite:  the universe and the number of ways you can screw up.\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from peft import PeftModel, PeftConfig\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "peft_model_id = \"outputs/checkpoint-20/\"\n",
    "config = PeftConfig.from_pretrained(peft_model_id)\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    config.base_model_name_or_path, return_dict=True, load_in_8bit=True, device_map=\"auto\"\n",
    ")\n",
    "tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)\n",
    "\n",
    "# Load the Lora model\n",
    "model = PeftModel.from_pretrained(model, peft_model_id)\n",
    "\n",
    "batch = tokenizer(\"Two things are infinite: \", return_tensors=\"pt\")\n",
    "\n",
    "with torch.cuda.amp.autocast():\n",
    "    output_tokens = model.generate(**batch, max_new_tokens=50)\n",
    "\n",
    "print(\"output：\\n\\n\", tokenizer.decode(output_tokens[0], skip_special_tokens=True))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
